%load_ext pretty_jupyter
Deep Learning for Tree Species Classification Tutorial¶
Run in Google Colab | View on Github | Download Data
Workflow
- Set up the Dataset
- Create a model
- Train
- Test/Visualize result
- Tune the network
- Save/Deploy your model
- Scale up your model
But first, let's install and import the necessary libraries:
%pip install lightning gdown
try:
    from google.colab import drive
    IN_COLAB = True
    # Mount the Google Drive at mount
    mount = '/content/gdrive'
    print("Colab: mounting Google drive on ", mount)
    drive.mount(mount)
except ImportError:
    # google.colab is only importable inside Colab
    IN_COLAB = False

if IN_COLAB:
    print("We're running Colab")
Colab: mounting Google drive on /content/gdrive Mounted at /content/gdrive We're running Colab
# Switch to the directory on the Google Drive that you want to use
import os
drive_root = mount + "/MyDrive/tree_species_classification"
# Change to the directory
print("\nColab: Changing directory to ", drive_root)
%cd $drive_root
Colab: Changing directory to /content/gdrive/MyDrive/tree_species_classification /content/gdrive/MyDrive/tree_species_classification
Create a deep learning model¶
import lightning as L
from PIL import Image
import torchvision.transforms as transforms
from torchvision.models import resnet50, ResNet50_Weights
from matplotlib import pyplot as plt
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.loggers import TensorBoardLogger
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch
from torch.nn import functional as F
from torch import nn
import geopandas as gpd
from sklearn.model_selection import train_test_split
import pandas as pd
from pathlib import Path
from os.path import join
from sklearn.metrics import confusion_matrix
import seaborn as sns
import sklearn
Load Crown Data¶
import gdown
import zipfile
# Google Drive file ID (Extracted from the shared link)
file_id = "1I8Lb3mAlkrUSSmdTyLQPQ52HhsGbF6qX"  # Replace this with your actual file ID
# Construct the download URL from the file ID
url = f"https://drive.google.com/uc?id={file_id}"
# Define the output filename
output = "data.zip"
# Download the file
print(f"Downloading {output} from Google Drive...")
gdown.download(url, output, quiet=False)
# Extract the ZIP file
print("Extracting files...")
with zipfile.ZipFile(output, "r") as zip_ref:
    zip_ref.extractall()
print("Extraction complete!")
Downloading data.zip from Google Drive...
Extracting files... Extraction complete!
#Load the crown polygons
crowns_df = gpd.read_file('data/tree_crowns_subset.gpkg')
# Map class labels to binary values
label_mapping = {'coniferous': 0, 'deciduous': 1}
crowns_df['label'] = crowns_df['species_type'].map(label_mapping)
#Set data dir
img_dir = 'data/clipped_crowns'
img_fpaths = list(Path(img_dir).glob("*.png"))
#Convert fpaths ls to data frame
img_df = pd.DataFrame(img_fpaths, columns=['fpath'])
img_df['crown_id'] = img_df['fpath'].apply(lambda x: int(x.stem.split(".")[0].split("_")[1]))
#Join with crowns_df
crowns_df = crowns_df.merge(img_df, on='crown_id', how='left')
crowns_df.head(5)
| | label | common_name | scientific_name | genus | crown_id | species_type | minx | miny | maxx | maxy | geometry | fpath |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | Balsam fir | Abies balsamea | Pinaceae | 8340 | deciduous | 577189.0365 | 5.093486e+06 | 577192.0568 | 5.093488e+06 | MULTIPOLYGON (((577191.446 5093488.217, 577191... | data/clipped_crowns/crown_8340.png |
| 1 | 1 | Balsam fir | Abies balsamea | Pinaceae | 9399 | deciduous | 576957.3289 | 5.093309e+06 | 576960.9351 | 5.093313e+06 | MULTIPOLYGON (((576958.412 5093313.133, 576958... | data/clipped_crowns/crown_9399.png |
| 2 | 1 | Balsam fir | Abies balsamea | Pinaceae | 2458 | deciduous | 577064.1428 | 5.093336e+06 | 577066.9213 | 5.093339e+06 | MULTIPOLYGON (((577066.056 5093338.765, 577065... | data/clipped_crowns/crown_2458.png |
| 3 | 1 | Balsam fir | Abies balsamea | Pinaceae | 2492 | deciduous | 577052.4109 | 5.093352e+06 | 577054.2873 | 5.093355e+06 | MULTIPOLYGON (((577054.098 5093354.535, 577054... | data/clipped_crowns/crown_2492.png |
| 4 | 1 | Balsam fir | Abies balsamea | Pinaceae | 567 | deciduous | 577186.6727 | 5.093215e+06 | 577191.7753 | 5.093218e+06 | MULTIPOLYGON (((577190.923 5093217.595, 577190... | data/clipped_crowns/crown_567.png |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 595 | 0 | Red maple | Acer rubrum | Sapindaceae | 54 | coniferous | 577088.2029 | 5.093114e+06 | 577093.3687 | 5.093119e+06 | MULTIPOLYGON (((577090.925 5093119.305, 577090... | data/clipped_crowns/crown_54.png |
| 596 | 0 | Red maple | Acer rubrum | Sapindaceae | 1327 | coniferous | 577074.5608 | 5.093307e+06 | 577076.5100 | 5.093309e+06 | MULTIPOLYGON (((577076.149 5093308.862, 577076... | data/clipped_crowns/crown_1327.png |
| 597 | 0 | Red maple | Acer rubrum | Sapindaceae | 6126 | coniferous | 577308.0109 | 5.093633e+06 | 577310.8445 | 5.093635e+06 | MULTIPOLYGON (((577310.509 5093634.769, 577310... | data/clipped_crowns/crown_6126.png |
| 598 | 0 | Red maple | Acer rubrum | Sapindaceae | 5284 | coniferous | 577443.5990 | 5.093582e+06 | 577452.3151 | 5.093589e+06 | MULTIPOLYGON (((577448.862 5093588.134, 577448... | data/clipped_crowns/crown_5284.png |
| 599 | 0 | Red maple | Acer rubrum | Sapindaceae | 6506 | coniferous | 577315.9984 | 5.093473e+06 | 577319.4037 | 5.093477e+06 | MULTIPOLYGON (((577318.737 5093475.773, 577318... | data/clipped_crowns/crown_6506.png |
600 rows × 12 columns
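The `crown_id` parsing above assumes every image is named `crown_<id>.png`, and the left merge leaves `fpath` empty for any crown that has no image. As a minimal stdlib sketch (the helper names `parse_crown_id` and `missing_images` are hypothetical), here is the same parsing plus a check for crowns without an image:

```python
from pathlib import Path

def parse_crown_id(fpath):
    """Extract the integer crown ID from a file named like 'crown_8340.png'."""
    # Path.stem drops the directory and extension ('crown_8340'); the ID follows the underscore
    return int(Path(fpath).stem.split(".")[0].split("_")[1])

def missing_images(crown_ids, image_paths):
    """Return crown IDs with no matching image (these get an empty fpath after a left merge)."""
    found = {parse_crown_id(p) for p in image_paths}
    return sorted(set(crown_ids) - found)

paths = ["data/clipped_crowns/crown_8340.png", "data/clipped_crowns/crown_54.png"]
print(parse_crown_id(paths[0]))                 # 8340
print(missing_images([8340, 54, 9399], paths))  # [9399]
```

On the real dataframe, `crowns_df['fpath'].isna().sum()` gives the number of crowns without an image directly.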
import seaborn as sns
import matplotlib.pyplot as plt
# Create the count plot with 'label'
ax = sns.countplot(data=crowns_df, x='label', hue='label', palette='viridis', legend=False)
# Add a custom legend
legend_labels = {0: 'Coniferous', 1: 'Deciduous'}
handles = [plt.Rectangle((0, 0), 1, 1, color=ax.patches[i].get_facecolor()) for i in range(len(legend_labels))]
plt.legend(handles, legend_labels.values(), title="Tree Type")
# Set labels and title
plt.xlabel('Label')
plt.ylabel('Count')
plt.title('Distribution of Labels')
plt.show()
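Class balance matters when reading accuracy later on: a model that always predicts the majority class already scores that class's share of the data. A minimal sketch with toy labels (not the real counts); on the notebook data you could call the hypothetical helper as `majority_baseline(crowns_df['label'].tolist())`:

```python
from collections import Counter

def majority_baseline(labels):
    """Return the most frequent label and the accuracy of always predicting it."""
    counts = Counter(labels)
    majority_label, majority_count = counts.most_common(1)[0]
    return majority_label, majority_count / len(labels)

# Toy labels (0 = coniferous, 1 = deciduous), not the real dataset
labels = [0, 0, 0, 1, 1, 0, 1, 0]
print(Counter(labels))            # Counter({0: 5, 1: 3})
print(majority_baseline(labels))  # (0, 0.625)
```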
Set up the Dataset¶
class TreeCrownDataset(Dataset):
    def __init__(self, crowns_df, split, target_res=256, train_augmentations=None):
        self.target_res = target_res
        self.split = split
        self.crowns_df = crowns_df
        # Copy into a new list to avoid a shared mutable default argument
        self.train_augmentations = list(train_augmentations or [])
        # Resize the crown images and convert them to float tensors scaled to [0, 1]
        self.transforms = [
            transforms.Resize((target_res, target_res)),
            transforms.ToTensor(),
        ]
        # Add additional transforms for data augmentation if using the train dataset
        if self.split == 'train':
            self.transforms.extend(self.train_augmentations)
        # Build the transform pipeline
        self.transforms = transforms.Compose(self.transforms)

    def __len__(self):
        return len(self.crowns_df)

    def __getitem__(self, idx):
        target_crown = self.crowns_df.iloc[idx]
        label = torch.tensor(target_crown['label']).long()
        crown_img = Image.open(target_crown['fpath']).convert('RGB')
        crown_tensor = self.transforms(crown_img)
        crown_id = target_crown['crown_id']
        return crown_tensor, label, crown_id
Set up the Lightning Data Module¶
class TreeCrownDataModule(L.LightningDataModule):
    def __init__(self, crowns_df, batch_size=32, train_augmentations=None):
        super().__init__()
        self.crowns_df = crowns_df
        self.batch_size = batch_size
        # Store the augmentations so they actually reach the training dataset
        self.train_augmentations = list(train_augmentations or [])

    def setup(self, stage=None):
        # Split data into three dataframes for train/val/test
        train_val_df, self.test_df = train_test_split(self.crowns_df,
                                                      test_size=0.15,
                                                      random_state=42)
        self.train_df, self.val_df = train_test_split(train_val_df,
                                                      test_size=0.17,
                                                      random_state=42)
        # Report dataset sizes
        for name, df in [("Train", self.train_df),
                         ("Val", self.val_df),
                         ("Test", self.test_df)]:
            print(f"{name} dataset size: {len(df)}",
                  f"({round(len(df)/len(self.crowns_df)*100, 0)}%)")
        # Instantiate datasets (augmentations are only applied to the train split)
        self.train_dataset = TreeCrownDataset(self.train_df, split='train',
                                              train_augmentations=self.train_augmentations)
        self.val_dataset = TreeCrownDataset(self.val_df, split='val')
        self.test_dataset = TreeCrownDataset(self.test_df, split='test')

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)

    def predict_dataloader(self):
        # Predictions are made on the held-out test split
        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)
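The printed split sizes follow from the two `train_test_split` calls. Assuming scikit-learn's rounding for a float `test_size` (the test count is rounded up and the remainder goes to train), this hypothetical `split_sizes` helper reproduces them:

```python
import math

def split_sizes(n, test_size):
    """Sizes of a float train/test split: test is rounded up, train gets the rest."""
    n_test = math.ceil(test_size * n)
    return n - n_test, n_test

train_val, test = split_sizes(600, 0.15)   # (510, 90)
train, val = split_sizes(train_val, 0.17)  # 0.17 * 510 = 86.7 -> (423, 87)
print(train, val, test)                    # 423 87 90
```

Overall this gives roughly a 70% / 15% / 15% split of the 600 crowns.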
#Set the training data augmentations
train_augmentations = [
transforms.RandomHorizontalFlip(),
transforms.RandomRotation([-90, 90]),
transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0))
]
# Test the datamodule
crowns_datamodule = TreeCrownDataModule(crowns_df, train_augmentations=[])
crowns_datamodule.setup()
# Test loading a sample
sample = crowns_datamodule.train_dataset[0]
print(sample[0].shape)
print(sample[1])
Train dataset size: 423 (70.0%) Val dataset size: 87 (14.0%) Test dataset size: 90 (15.0%) torch.Size([3, 256, 256]) tensor(1)
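The batch size also fixes how many optimizer steps make up one epoch. A small sketch with a hypothetical helper; `drop_last` mirrors the `DataLoader` argument of the same name, which the notebook leaves at its default (`False`):

```python
import math

def batches_per_epoch(n_samples, batch_size, drop_last=False):
    """Number of batches a DataLoader yields per epoch."""
    if drop_last:
        return n_samples // batch_size
    return math.ceil(n_samples / batch_size)

n_batches = batches_per_epoch(423, 32)   # 13 full batches + 1 partial batch
last_batch = 423 - 32 * (n_batches - 1)  # size of the final, partial batch
# 0-indexed global step reached at the end of each of the first three epochs
steps = [n_batches * (epoch + 1) - 1 for epoch in range(3)]
print(n_batches, last_batch, steps)      # 14 7 [13, 27, 41]
```

This is consistent with the `step` column of the training logs (13, 27, 41, ...), which advances by 14 per epoch.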
Set up the model¶
class CNN(L.LightningModule):
    def __init__(self, lr, pretrained_weights=True):
        super().__init__()
        # ImageNet-pretrained weights (ResNet50_Weights.DEFAULT = IMAGENET1K_V2) vs. random init
        self.model = resnet50(weights=ResNet50_Weights.DEFAULT if pretrained_weights else None)
        # Replace the final fc layer so the model outputs a single value for binary classification
        self.model.fc = nn.Linear(self.model.fc.in_features, 1)
        # Add a sigmoid so the output is the probability of class 1 (deciduous)
        self.model = nn.Sequential(self.model, nn.Sigmoid())
        self.criterion = nn.BCELoss()
        self.lr = lr
        self.save_hyperparameters()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y, _ = batch
        y_hat = self(x).squeeze(1)  # (B, 1) -> (B,), safe even when B == 1
        loss = self.criterion(y_hat, y.float())
        self.log('train_loss', loss, on_epoch=True, on_step=False)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y, _ = batch
        y_hat = self(x).squeeze(1)
        loss = self.criterion(y_hat, y.float())
        self.log('val_loss', loss, on_epoch=True, on_step=False)
        return loss

    def predict_step(self, batch, batch_idx):
        x, y, crown_ids = batch
        y_hat = self(x).squeeze(1)
        return y_hat, y, crown_ids

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer
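The model ends in `nn.Sigmoid()` and is trained with `nn.BCELoss()`. A plain-Python sketch of that loss, next to the fused logit form that PyTorch's `nn.BCEWithLogitsLoss` computes for numerical stability, shows why an untrained head that outputs probabilities near 0.5 starts with a loss close to ln 2 ≈ 0.693 (the function names are illustrative, not PyTorch APIs):

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def bce(p, y):
    """Binary cross-entropy for one predicted probability p and label y in {0, 1}."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

def bce_with_logits(z, y):
    """Same loss computed from the raw logit z (before the sigmoid), in a numerically stable form."""
    return max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))

# An untrained head predicts ~0.5 for every crown, so the loss starts near ln(2)
print(round(bce(0.5, 1), 4))  # 0.6931
# Both forms agree for moderate logits...
z = 2.0
print(abs(bce(sigmoid(z), 1) - bce_with_logits(z, 1)) < 1e-9)  # True
# ...but sigmoid(40.0) rounds to exactly 1.0 in float64, so bce(sigmoid(40.0), 0)
# would hit log(0); the logit form still returns a finite loss
print(sigmoid(40.0) == 1.0)       # True
print(bce_with_logits(40.0, 0))   # 40.0
```

Switching to `nn.BCEWithLogitsLoss` would mean removing the `nn.Sigmoid()` from the model and applying `torch.sigmoid` in `predict_step`; the notebook keeps the explicit sigmoid, which is easier to read.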
# Instantiate the model (a single sigmoid output: the probability that a crown is deciduous)
model = CNN(lr=0.0001)
print(model)
#Try passing some data through the model
batch, labels, ids = next(iter(crowns_datamodule.train_dataloader()))
# Pass batch through the model
y_hat = model(batch)
print("\nCrown IDs:\n", ids)
print("\nImage batch shape:\n", batch.shape)
print("\nOutput tensor shape:\n", y_hat.shape)
#View the predicted class probabilities
print("\nPredicted class probabilities:\n",
y_hat.detach().cpu().numpy().squeeze())
CNN(
(model): Sequential(
(0): ResNet(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
(0): Bottleneck(
(conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): Bottleneck(
(conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): Bottleneck(
(conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
(layer2): Sequential(
(0): Bottleneck(
(conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): Bottleneck(
(conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): Bottleneck(
(conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(3): Bottleneck(
(conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
(layer3): Sequential(
(0): Bottleneck(
(conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): Bottleneck(
(conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): Bottleneck(
(conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(3): Bottleneck(
(conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(4): Bottleneck(
(conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(5): Bottleneck(
(conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
(layer4): Sequential(
(0): Bottleneck(
(conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): Bottleneck(
(conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): Bottleneck(
(conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
(fc): Linear(in_features=2048, out_features=1, bias=True)
)
(1): Sigmoid()
)
(criterion): BCELoss()
)
Crown IDs:
tensor([2648, 1446, 579, 9226, 5744, 6641, 6426, 6865, 3832, 2560, 5034, 4461,
233, 4119, 7429, 6221, 4200, 1362, 9236, 5770, 2350, 1438, 3407, 5438,
6318, 891, 3223, 2488, 7067, 9233, 6737, 1494])
Image batch shape:
torch.Size([32, 3, 256, 256])
Output tensor shape:
torch.Size([32, 1])
Predicted class probabilities:
[0.49712744 0.4790547 0.5167126 0.45847675 0.47571528 0.50378984
0.4806627 0.47133324 0.46858212 0.4780063 0.4591862 0.49085122
0.5068038 0.49928945 0.48584807 0.505633 0.50510865 0.48676297
0.5686578 0.48190606 0.49865714 0.4582818 0.47904262 0.53434956
0.44840366 0.50194734 0.48886827 0.4564724 0.5279636 0.5000776
0.5013218 0.51534027]
Set up the trainer¶
# Put everything together: data module, model, loggers, and trainer
crowns_datamodule = TreeCrownDataModule(crowns_df, train_augmentations=[])
crowns_datamodule.setup()
model = CNN(lr=0.0001)
csv_logger = CSVLogger('', name='logs', version=0)
tensorboard_logger = TensorBoardLogger('', name='lightning_logs', version=0)
trainer = L.Trainer(max_epochs=10, logger=[csv_logger, tensorboard_logger], devices=1)
trainer.fit(model, datamodule=crowns_datamodule)
Train dataset size: 423 (70.0%) Val dataset size: 87 (14.0%) Test dataset size: 90 (15.0%)
Train dataset size: 423 (70.0%) Val dataset size: 87 (14.0%) Test dataset size: 90 (15.0%)
Visualize training process¶
# Read the logs CSV file after training
logs_df = pd.read_csv(csv_logger.log_dir + '/metrics.csv')
logs_df = logs_df.groupby('epoch').mean() # merge the train and valid rows
logs_df['epoch'] = logs_df.index # because "Epoch" gets turned into the index
logs_df.index.name = '' # to remove the name "Epoch" from the index
# Display the logs
print(logs_df)
step train_loss val_loss epoch
0 13.0 0.606503 0.543630 0
1 27.0 0.362105 0.382762 1
2 41.0 0.210748 0.319108 2
3 55.0 0.104064 0.236043 3
4 69.0 0.061297 0.199528 4
5 83.0 0.046444 0.209622 5
6 97.0 0.028861 0.186452 6
7 111.0 0.039431 0.201574 7
8 125.0 0.020977 0.184163 8
9 139.0 0.015162 0.166470 9
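Reading the logged values directly: the validation loss is lowest at the last epoch, so training longer may still help, while the gap between validation and training loss widens from about -0.06 to about 0.15, the usual sign to watch for overfitting. A short sketch using the numbers from the table above:

```python
# Values copied from the logs table above (epochs 0-9)
train_loss = [0.606503, 0.362105, 0.210748, 0.104064, 0.061297,
              0.046444, 0.028861, 0.039431, 0.020977, 0.015162]
val_loss = [0.543630, 0.382762, 0.319108, 0.236043, 0.199528,
            0.209622, 0.186452, 0.201574, 0.184163, 0.166470]

# Epoch with the lowest validation loss
best_epoch = min(range(len(val_loss)), key=val_loss.__getitem__)
# Generalization gap per epoch (validation minus training loss)
gaps = [round(v - t, 3) for t, v in zip(train_loss, val_loss)]
print(best_epoch, val_loss[best_epoch])  # 9 0.16647
print(gaps)
```

In Lightning, a `ModelCheckpoint(monitor="val_loss", mode="min")` callback from `lightning.pytorch.callbacks` can keep the best epoch's weights automatically.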
#Plot learning curve
plt.figure(figsize=(10, 6))
plt.plot(logs_df['train_loss'], label='Train Loss')
plt.plot(logs_df['val_loss'], label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()
Alternatively, you can use TensorBoard:
# Note: `%rm -rf ./lightning_logs/` would delete the TensorBoard logs written above,
# so only clear old logs that way before training, not here
%reload_ext tensorboard
%tensorboard --logdir=lightning_logs/
def calc_test_oa():
    # Predict on the test set (uses the global trainer, model, crowns_datamodule and crowns_df)
    out = trainer.predict(model, datamodule=crowns_datamodule, return_predictions=True)
    # Separate predictions, targets, and crown IDs from the per-batch outputs
    pred_class_probs = np.concatenate([batch[0] for batch in out])
    obs = np.concatenate([batch[1] for batch in out])
    ids = np.concatenate([batch[2] for batch in out])
    # Convert to an obs-pred dataframe
    test_df = pd.DataFrame({'obs': obs, 'pred_class_probs': pred_class_probs, 'crown_id': ids})
    # Convert class probabilities to binary predictions (threshold 0.5)
    test_df['pred_boolean_class'] = (test_df['pred_class_probs'] > 0.5)
    # Convert binary predictions to integers
    test_df['pred'] = test_df['pred_boolean_class'].astype(int)
    # Add a column for correct/incorrect predictions
    test_df['correct'] = test_df['obs'] == test_df['pred']
    # Join with crowns_df
    test_df = test_df.merge(crowns_df, on='crown_id', how='left')
    # Calculate overall accuracy using sklearn
    overall_acc = sklearn.metrics.accuracy_score(y_true=test_df['obs'], y_pred=test_df['pred'])
    # Check how many crowns were classified correctly
    n_correct = int(test_df['correct'].sum())
    print(f"Summary: {n_correct} / {len(test_df)} crowns were classified correctly.")
    return overall_acc, test_df
overall_acc, test_df = calc_test_oa()
print(f"Overall accuracy: {overall_acc:.2f}")
Train dataset size: 423 (70.0%) Val dataset size: 87 (14.0%) Test dataset size: 90 (15.0%)
Summary: 78 / 90 crowns were classified correctly. Overall accuracy: 0.87
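Overall accuracy here is simply the fraction of test crowns whose thresholded probability matches the observed label. A minimal sketch of the same computation (`overall_accuracy` is a hypothetical helper; it uses the strict `> 0.5` threshold from `calc_test_oa`):

```python
def overall_accuracy(probs, obs, threshold=0.5):
    """Threshold class probabilities and return (n_correct, accuracy)."""
    preds = [int(p > threshold) for p in probs]  # strict '>' as in calc_test_oa
    n_correct = sum(int(p == o) for p, o in zip(preds, obs))
    return n_correct, n_correct / len(obs)

# Toy example: four crowns, one misclassified
print(overall_accuracy([0.9, 0.2, 0.6, 0.4], [1, 0, 0, 0]))  # (3, 0.75)
# The reported summary: 78 of 90 test crowns correct
print(f"{78 / 90:.2f}")  # 0.87
```

With 90 test crowns, each crown moves the accuracy by about 1.1 percentage points (1/90), so small differences between runs should be read with care.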
print(label_mapping)
# Compute the confusion matrix with sklearn (rows = observed, columns = predicted)
cm = confusion_matrix(y_true=test_df['obs'],
                      y_pred=test_df['pred'])
# Plot the confusion matrix with seaborn
classes = ['Coniferous', 'Deciduous']
sns.heatmap(cm, annot=True,
            fmt='d',  # show raw counts as integers
            cmap='YlGn',
            xticklabels=classes,
            yticklabels=classes)
plt.xlabel('Predicted')
plt.ylabel('Observed')
plt.title('Confusion Matrix')
plt.show()
{'coniferous': 0, 'deciduous': 1}
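sklearn's `confusion_matrix` puts observed classes on the rows and predicted classes on the columns, which is why the heatmap's y-axis is labelled "Observed". A stdlib sketch with toy labels (hypothetical helpers, not the real predictions) showing that layout and the per-class recall it yields:

```python
def confusion_2x2(obs, pred):
    """2x2 confusion matrix with rows = observed, columns = predicted (sklearn's layout)."""
    cm = [[0, 0], [0, 0]]
    for o, p in zip(obs, pred):
        cm[o][p] += 1
    return cm

def per_class_recall(cm):
    """Fraction of each observed class that was predicted correctly (row-wise)."""
    return [cm[i][i] / sum(cm[i]) if sum(cm[i]) else float("nan") for i in range(2)]

# Toy labels: 0 = coniferous, 1 = deciduous (as in label_mapping)
obs = [0, 0, 0, 1, 1, 1, 1, 0]
pred = [0, 0, 0, 1, 1, 0, 0, 0]
cm = confusion_2x2(obs, pred)
print(cm)                    # [[4, 0], [2, 2]]
print(per_class_recall(cm))  # [1.0, 0.5]
```

For the real predictions, `sklearn.metrics.classification_report(test_df['obs'], test_df['pred'])` reports per-class precision and recall.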
# Let's view the incorrectly classified crowns
incorrect_df = test_df[~test_df['correct']]
# Plot incorrectly classified coniferous/deciduous crowns
for c_type in test_df['species_type'].unique():
    print(f"\nIncorrectly classified {c_type} crowns.\n")
    # Filter the incorrect crowns by species type
    incorrect_type_df = incorrect_df[incorrect_df['species_type'] == c_type]
    # Number of images
    num_images = len(incorrect_type_df)
    # Determine the grid size (at least 1x1, just large enough for all images)
    grid_size = max(1, int(np.ceil(num_images ** 0.5)))
    # Create a figure and axes (squeeze=False always returns a 2D array of axes)
    fig, axes = plt.subplots(grid_size, grid_size, figsize=(15, 15), squeeze=False)
    # Flatten the axes array for easy iteration
    axes = axes.flatten()
    # Read the incorrect crown files and plot them
    for ax, fpath in zip(axes, incorrect_type_df['fpath']):
        img = Image.open(fpath)
        ax.imshow(img)
        ax.axis('off')
    # Hide any remaining empty subplots
    for ax in axes[num_images:]:
        ax.axis('off')
    plt.tight_layout()
    plt.show()
Incorrectly classified deciduous crowns.
Incorrectly classified coniferous crowns.
Tune hyperparameters¶
Forget about ML for a second. Imagine you are baking a cookie. You have 3 things you can change about the cookie:
- Sugar type (white, brown, cane)
- Baking time (15 minutes, 30 minutes)
- Cooking temperature (360, 400 degrees)
There are 3 × 2 × 2 = 12 possible cookie variations you can make. One of them will be the most delicious.
To find out which cookie tastes best, you need to bake every variation and give each one a score on a four-level scale, from 🤢 (worst) up to the best.
This is called a hyperparameter sweep. Your three hyperparameters are sugar type, baking time, and cooking temperature.
python make_cookie.py --sugar 'white' --baking_time 15 --temperature 400
python make_cookie.py --sugar 'brown' --baking_time 15 --temperature 400
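A grid sweep is just the Cartesian product of the candidate values. A sketch that enumerates all 12 cookie variations and prints the corresponding commands (`make_cookie.py` is the illustrative script from the text, not a real program):

```python
from itertools import product

sugars = ["white", "brown", "cane"]
baking_times = [15, 30]    # minutes
temperatures = [360, 400]  # degrees

# One entry per combination: 3 x 2 x 2 = 12
grid = list(product(sugars, baking_times, temperatures))
print(len(grid))  # 12
for sugar, baking_time, temperature in grid:
    print(f"python make_cookie.py --sugar '{sugar}' --baking_time {baking_time} --temperature {temperature}")
```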
So, which combination of hyperparameters produces the best-performing model?
The definition of "best" depends on the work you are doing. In general, "best" refers to the lowest loss. At Lightning, we tend to think of "best" as the lowest loss for the least amount of time spent training.
If we run this training script with different hyperparameter combinations, we get different loss curves. The tests below vary the learning rate and whether pretrained weights are used.
Test 1: random initialization instead of pretrained weights (lr = 0.01)¶
# Test 1: random initialization (pretrained_weights=False) with lr=0.01
crowns_datamodule = TreeCrownDataModule(crowns_df, train_augmentations=[])
crowns_datamodule.setup()
csv_logger = CSVLogger('', name='logs', version=1)
tensorboard_logger = TensorBoardLogger('', name='lightning_logs', version=1)
model = CNN(lr=0.01, pretrained_weights=False)
trainer = L.Trainer(max_epochs=10, logger=[csv_logger, tensorboard_logger], devices=1)
trainer.fit(model, datamodule=crowns_datamodule)
Train dataset size: 423 (70.0%) Val dataset size: 87 (14.0%) Test dataset size: 90 (15.0%)
Train dataset size: 423 (70.0%) Val dataset size: 87 (14.0%) Test dataset size: 90 (15.0%)
overall_acc, test_df = calc_test_oa()
print(f"Overall accuracy: {overall_acc:.2f}")
Train dataset size: 423 (70.0%) Val dataset size: 87 (14.0%) Test dataset size: 90 (15.0%)
Summary: 65 / 90 crowns were classified correctly. Overall accuracy: 0.72
Test 2: pretrained weights with a different learning rate (lr = 0.01)¶
# Test 2: pretrained weights with lr=0.01 (reuses the data module from test 1)
model = CNN(lr=0.01)
csv_logger = CSVLogger('', name='logs', version=2)
tensorboard_logger = TensorBoardLogger('', name='lightning_logs', version=2)
trainer = L.Trainer(max_epochs=10, logger=[csv_logger, tensorboard_logger], devices=1)
trainer.fit(model, datamodule=crowns_datamodule)
Train dataset size: 423 (70.0%) Val dataset size: 87 (14.0%) Test dataset size: 90 (15.0%)
overall_acc, test_df = calc_test_oa()
print(f"Overall accuracy: {overall_acc:.2f}")
Train dataset size: 423 (70.0%) Val dataset size: 87 (14.0%) Test dataset size: 90 (15.0%)
Summary: 73 / 90 crowns were classified correctly. Overall accuracy: 0.81
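The three runs so far change one setting at a time: version_1 and version_2 differ only in the initial weights, while version_0 and version_2 differ only in the learning rate. A short sketch tabulating the reported test results:

```python
# Test-set results reported above for the three runs (logger versions 0, 1, 2)
runs = {
    "version_0: lr=1e-4, pretrained": 78,
    "version_1: lr=1e-2, random init": 65,
    "version_2: lr=1e-2, pretrained": 73,
}
n_test = 90

# Print runs from best to worst overall accuracy
for name, n_correct in sorted(runs.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{name}: {n_correct}/{n_test} = {n_correct / n_test:.2f}")

best = max(runs, key=runs.get)
print("Best:", best)  # Best: version_0: lr=1e-4, pretrained
```

Picking hyperparameters by test accuracy leaks information from the test set; in a real sweep you would compare runs by `val_loss` and report the test score only for the final choice.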
%reload_ext tensorboard
%tensorboard --logdir=lightning_logs/
Save/Deploy your model¶
trainer.save_checkpoint(filepath=".ckpt/model.ckpt")
model = CNN.load_from_checkpoint(".ckpt/model.ckpt", lr=0.01)
model.freeze()
crowns_datamodule = TreeCrownDataModule(crowns_df, train_augmentations=[])
crowns_datamodule.setup()
test_predictions = trainer.predict(model, datamodule=crowns_datamodule)
Train dataset size: 423 (70.0%) Val dataset size: 87 (14.0%) Test dataset size: 90 (15.0%) Train dataset size: 423 (70.0%) Val dataset size: 87 (14.0%) Test dataset size: 90 (15.0%)
overall_acc, test_df = calc_test_oa()
print(f"Overall accuracy: {overall_acc:.2f}")
Train dataset size: 423 (70.0%) Val dataset size: 87 (14.0%) Test dataset size: 90 (15.0%)
Summary: 73 / 90 crowns were classified correctly. Overall accuracy: 0.81
TorchScript lets you serialize a model so that it can be loaded in non-Python environments. The LightningModule has a handy method, to_torchscript(), that returns a scripted module which you can save or use directly.
script = model.to_torchscript()
# save for use in production environment
torch.jit.save(script, ".ckpt/model.pt")
# Use it: load the scripted module and pass a test batch through it
batch, labels, ids = next(iter(crowns_datamodule.predict_dataloader()))
scripted_module = torch.jit.load(".ckpt/model.pt")
output = scripted_module(batch)
Scale up your model/dataset¶
You can either make all cookies sequentially (which will take you 4.5 hours). Or you can get 12 kitchens and cook them all in parallel, and you'll know in 30 minutes.
If a kitchen is a GPU, then you need 12 GPUs to run each experiment to see which cookie is the best. The power of Lightning is the ability to run sweeps like this on 12 different GPUs (or 1,000 GPUs if you'd like) to get you the best version of a model fast.
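The kitchen arithmetic, as a sketch: one kitchen bakes the 12 variations back to back, while 12 kitchens only have to wait for the longest bake (ignoring preheating and setup time):

```python
from itertools import product

baking_times = [15, 30]  # minutes
# Baking time of each of the 3 sugars x 2 baking times x 2 temperatures = 12 variations
bake_times = [t for _, t, _ in product(["white", "brown", "cane"], baking_times, [360, 400])]

sequential = sum(bake_times)  # one kitchen, one cookie after another
parallel = max(bake_times)    # 12 kitchens at once: wait for the slowest bake
print(sequential / 60, parallel)  # 4.5 30
```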
Train on GPUs
The Trainer will run on all available GPUs by default. Make sure you're running on a machine with at least one GPU. There's no need to specify any NVIDIA flags, as Lightning will do it for you.
from lightning import Trainer
# run on as many GPUs as available by default
trainer = Trainer(accelerator="auto", devices="auto", strategy="auto")
# equivalent to
trainer = Trainer()
# run on one GPU
trainer = Trainer(accelerator="gpu", devices=1)
# run on multiple GPUs
trainer = Trainer(accelerator="gpu", devices=8)
# choose the number of devices automatically
trainer = Trainer(accelerator="gpu", devices="auto")
Train on a SLURM cluster
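# train.py
# Assumes CNN, TreeCrownDataModule and crowns_df (defined above) are available in this script
from lightning import Trainer

def main(args):
    # args.lr and args.batch_size are hypothetical CLI arguments
    model = CNN(lr=args.lr)
    datamodule = TreeCrownDataModule(crowns_df, batch_size=args.batch_size)
    trainer = Trainer(accelerator="gpu", devices=8, num_nodes=4, strategy="ddp")
    trainer.fit(model, datamodule=datamodule)

if __name__ == "__main__":
    args = ...  # you can use your CLI parser of choice, the `LightningCLI`, or a config.yaml
    # TRAIN
    main(args)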
%%writefile submit.sh
#!/bin/bash -l
# (submit.sh) SLURM SUBMIT SCRIPT
# The shebang must be the first line of the file, otherwise sbatch rejects the script
#SBATCH --nodes=4             # This needs to match Trainer(num_nodes=...)
#SBATCH --gres=gpu:8
#SBATCH --ntasks-per-node=8   # This needs to match Trainer(devices=...)
#SBATCH --mem=0
#SBATCH --time=0-02:00:00
# activate conda env (pass its name as the first argument: sbatch submit.sh <env_name>)
source activate $1
# debugging flags (optional)
export NCCL_DEBUG=INFO
export PYTHONFAULTHANDLER=1
# on your cluster you might need these:
# set the network interface
# export NCCL_SOCKET_IFNAME=^docker0,lo
# might need the latest CUDA
# module load NCCL/2.4.7-1-cuda.10.0
# run script from above
srun python3 train.py
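The comments in `submit.sh` say the SLURM request has to match the `Trainer` arguments. A hypothetical pre-submission check (the helper names are illustrative) that parses the `#SBATCH` lines and compares them with `Trainer(num_nodes=4, devices=8)`:

```python
import re

def sbatch_value(script, option):
    """Return the value of an '#SBATCH --option=value' line, or None if absent."""
    match = re.search(rf"^#SBATCH\s+--{option}=(\S+)", script, flags=re.MULTILINE)
    return match.group(1) if match else None

def check_matches_trainer(script, num_nodes, devices):
    """Compare the SLURM request with Trainer(num_nodes=..., devices=...); also return the world size."""
    nodes = int(sbatch_value(script, "nodes"))
    tasks = int(sbatch_value(script, "ntasks-per-node"))
    return nodes == num_nodes and tasks == devices, nodes * tasks

script = """#!/bin/bash -l
#SBATCH --nodes=4             # This needs to match Trainer(num_nodes=...)
#SBATCH --gres=gpu:8
#SBATCH --ntasks-per-node=8   # This needs to match Trainer(devices=...)
"""
ok, world_size = check_matches_trainer(script, num_nodes=4, devices=8)
print(ok, world_size)  # True 32
```

4 nodes × 8 tasks per node gives 32 processes, one per GPU, for DDP.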
%%!
sbatch submit.sh
Or you can even parallelize the baking procedure...

wandb sweep¶
import wandb
wandb.login()
True
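A W&B sweep is described by a configuration dictionary. A minimal sketch of a grid sweep over the two settings tested above, minimizing the logged `val_loss`; the project name and `train_fn` are hypothetical, and the `wandb.sweep` / `wandb.agent` calls are left commented out so the snippet runs without wandb or network access:

```python
from itertools import product

# Grid over the two hyperparameters explored above
sweep_config = {
    "method": "grid",
    "metric": {"name": "val_loss", "goal": "minimize"},
    "parameters": {
        "lr": {"values": [1e-4, 1e-2]},
        "pretrained_weights": {"values": [True, False]},
    },
}

# A grid sweep launches one run per combination of parameter values
n_runs = len(list(product(*(p["values"] for p in sweep_config["parameters"].values()))))
print(n_runs)  # 4

# With wandb installed and logged in, the sweep could be started roughly like this
# (train_fn is a hypothetical function that reads wandb.config and calls trainer.fit):
# sweep_id = wandb.sweep(sweep_config, project="tree_species_classification")
# wandb.agent(sweep_id, function=train_fn)
```

Inside `train_fn`, a `WandbLogger` from `lightning.pytorch.loggers` would send `train_loss` and `val_loss` to W&B.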
%%html
<iframe src="https://api.wandb.ai/links/ubc-yuwei-cao/ebnspmv1" style="border:none;height:1024px;width:100%"></iframe>